tensorflow高层API -- Estimator

Estimator

tf.estimator是TensorFlow的高层API,对以下方法进行了封装

  • training: 对应 ModeKeys.TRAIN
  • evaluation: ModeKeys.EVAL
  • predict: ModeKeys.PREDICT
  • export for serving

ModeKeys里这几个对应

update: tf要不支持这个模块了

Importing from tensorflow.python.estimator is unsupported and will soon break!

Usage

有两种使用方式,都需要依赖tf.estimator.Estimator这个类。

pre-made

基于Estimator,官方预定义了一些常用的模型,比如DNNClassifierDNNRegressorLinearClassifierBoostedTreesClassifier等。

以下是一个完整实例: 利用DNNClassifier对花萼分类,其核心代码就几行,充分展现了Estimator封装的简洁性

1
2
3
classifier = tf.estimator.DNNClassifier(features, labels, hidden_units, n_classes)  # 实例化estimator
classifier.train(input_fn_train, steps) # 送入训练数据,开始training
eval_result = classifier.evaluate(input_fn_test) # 送入test data

官方这几个model也为我们使用Estimator提供了例子

custom

官方定义的模型不能够满足我们的需求,那就需要基于Estimator自定义模型

无论是pre-made,还是custom,其核心都是model function
在该方法中需要构建graph,包含training, evaluation, and prediction。

具体实现

Estimator的封装,基本思想是让我们只需关心模型数据,屏蔽硬件(是吗?)。

要实现这种封装,一种思路是抽象类(接口) + 继承 + 覆盖,另一种是完整类 + 传函数。Estimator显然更钟情于第二种方式(why?)。

  • model_fn 定义模型,即构建graph,主要要包含TRAIN、EVAL、PREDICT对应的op。

自己定义model_fn

model_fn传给Estimator

model_fn必须要以Estimator构造函数的方式传递

pre-made

DNNClassifier的实现

1
2
3
4
5
class DNNClassifier(estimator.Estimator): # 定义Estimator的封装,没必要
def __init__(.. # 重写init方法,仅仅是
def _model_fn(.. # 定义model_fn方法 (required)
# 传递model_fn (关键)
super(DNNClassifier, self).__init__(model_fn=_model_fn,..)

这仅仅是对Estimator的一个封装,对外直接用DNNClassifier而无视Estimator的存在。

自己的程序,没必要这样封装,屏蔽掉Estimator而定义一个新类,会让其他程序员看的摸不着头脑。

不建议采用继承的方式,还是开放Estimator比较好吧,即后面的custom方式

custom

input_fn

优势

  • Estimator对底层隔离,兼容CPU、GPU、TPU、多卡、多机多卡 666
  • 其他几个优点没看懂

Estimator

estimator与keras的关系,

Datasets for Estimators

给Estimator传入数据,通常为train何eval分别定义一个input_fn函数,该函数有三个参数:

  • features:字典或DataFrame类型,包含输入数据的特征,与Feature Columns对应
  • labels:标签数组
  • batch_size:batch size
1
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
  • repeat 要迭代数据集多个周期,例如,要创建一个将其输入重复 10 个周期的数据集repeat(10),无参表示无限次地重复输入。
  • shuffle 随机重排输入数据,维持一个固定大小的缓冲区,并从该缓冲区统一地随机选择下一个元素。
  • batch 每次取多少个数据

FAQ

  1. estimator节省了哪些操作?
    • 模型部分并未节省操作
    • 与硬件底层的隔离,免去了写多机多卡的code
  2. estimator能否与low-level API交叉使用?
  3. Estimator如何做多卡?多机并行?如何分配worker和ps?怎样的并行策略?
    • 另外是怎样封装的?
  4. Estimator是屏蔽硬件吗?我们还是可以在model定义时设置device吧?

扩展阅读

  1. https://www.tensorflow.org/guide/premade_estimators
  2. Estimator源码
  3. 实例:
    1. pre-made estimator实例
    2. custom estimator实例